1 Cross validation of random forest models, and permutation testing for model accuracy metrics

1.1 Load packages

# Use library() rather than require(): require() merely warns (returns FALSE)
# when a package is missing, letting the script fail later with confusing
# errors instead of stopping here.
library(tidyverse)
library(randomForest)
library(caret)
library(parallel)    # mclapply() / detectCores(), used throughout
library(doParallel)
library(pROC)
library(plotROC)
library(MLmetrics)

1.2 Define the list of parameters for each model, based on results of ‘rfGridSearchCV.rmd’ for each dataset:

# Best hyper-parameters per dataset (mtry, ntree, nodesize), taken from the
# grid-search results in 'rfGridSearchCV.rmd' for each dataset.
param_list = list(
  "FD"           = c("mtry" = 2350, "ntree" = 1501, "nodesize" = 7),
  "FC"           = c("mtry" = 2350, "ntree" = 1501, "nodesize" = 7),
  "FDPCA"        = c("mtry" = 30,   "ntree" = 801,  "nodesize" = 7),
  "FCPCA"        = c("mtry" = 30,   "ntree" = 801,  "nodesize" = 7),
  "Str"          = c("mtry" = 3,    "ntree" = 851,  "nodesize" = 5),
  "Str_noWBV"    = c("mtry" = 5,    "ntree" = 851,  "nodesize" = 5),
  "StrPCA"       = c("mtry" = 19,   "ntree" = 2001, "nodesize" = 12),
  "StrPCA_noWBV" = c("mtry" = 17,   "ntree" = 851,  "nodesize" = 5)
)
# display the parameter table
as.data.frame(param_list)

2 Fit the “baseline” models to generate OOB scores and variable importances.

This differs from the later models in that it conducts bagging, but not cross validation.

# Fit one bagged forest per dataset to obtain OOB scores and variable
# importances (bagging only; cross-validation happens in the next section).
# NOTE: with forked workers, set.seed() alone does NOT make mclapply()
# reproducible; L'Ecuyer-CMRG streams are required so each child gets a
# deterministic sub-stream (see ?mcparallel, mc.set.seed).
RNGkind("L'Ecuyer-CMRG")
set.seed(111) # for reproducibility.
bestForest.list = mclapply(X = names(data_list), FUN = function(x) {
  dat = data_list[[x]]; params = param_list[[x]]
  # pre-proc: drop identifier/covariate columns, then drop incomplete rows
  dat = na.omit(
    select(dat, -one_of("age","IDENT_SUBID","SUBJECTID_long",
                        "wave_to_pull", "cbcl_totprob_t")))
  dat$GROUP = factor(dat$GROUP)
  # stratified, balanced bootstrap: each tree samples as many COMPs as PIs
  randomForest(GROUP ~ ., data = dat,
               mtry = params[1],
               ntree = params[2],
               nodesize = params[3],
               strata = dat$GROUP,
               sampsize = rep(sum(dat$GROUP == "PI"), times = 2),
               importance = TRUE)
}, mc.cores = (detectCores() - 1))
names(bestForest.list) = names(data_list)
saveRDS(bestForest.list, "../../../output/RF/bestGridTunedRF.rds")

2.0.1 compile the OOB scores

# OOB accuracy per model. randomForest's $confusion matrix appends a
# "class.error" column; the original summed over it too, slightly deflating
# the denominator. Select only the class-name columns to get a square matrix.
OOBs = vapply(bestForest.list, function(forest) {
  cm = forest$confusion[, rownames(forest$confusion), drop = FALSE]
  sum(diag(cm)) / sum(cm)
}, numeric(1))

3 Run the cross-validation/permutation testing algorithm on all datasets

# Repeated (1000x) stratified train/test cross-validation per dataset, with a
# permutation null distribution computed for every test fold.
# BUG FIX: the original inner FUN declared `dat = dat, params = params` as
# default arguments — a recursive default-argument reference that errors with
# "promise already under evaluation" when evaluated. The closure below simply
# captures `dat` and `params` from the enclosing frame instead.
permutationResults = mclapply(X = names(data_list), mc.cores = detectCores() - 1,
  FUN = function(x) {
    dat = data_list[[x]]; params = param_list[[x]]
    # pre-proc: drop identifier/covariate columns, then drop incomplete rows
    dat = na.omit(
      select(dat, -one_of("age","IDENT_SUBID","SUBJECTID_long",
                           "wave_to_pull", "cbcl_totprob_t")))
    dat$GROUP = factor(dat$GROUP)
    # train/test split: 75% of PIs plus an equally sized COMP sample, so every
    # training set is balanced across groups
    index_pi = which(dat$GROUP == "PI"); index_comp = which(dat$GROUP != "PI")
    n_per_group = round(length(index_pi) * .75)
    train_i = replicate(n = 1000, simplify = FALSE,
                        expr = c(sample(index_pi, size = n_per_group),
                                 sample(index_comp, size = n_per_group)))
    # inner loop kept sequential: the outer mclapply already claims all cores,
    # and nesting mclapply at detectCores()-1 would oversubscribe quadratically
    crossVal.res = lapply(train_i, function(y) {
      training = dat[y, ]; test = dat[-y, ]
      fit = randomForest(GROUP ~ ., data = training,
                         mtry = params[1],
                         ntree = params[2],
                         nodesize = params[3],
                         strata = training$GROUP,
                         sampsize = c(n_per_group, n_per_group)
      )
      # predicted classes, class probabilities, and test-set accuracy
      pred = factor(predict(fit, newdata = test), levels = c("COMP", "PI"))
      actual = factor(test$GROUP, levels = c("COMP","PI"))
      prob1 = predict(fit, newdata = test, type = "prob")[,"PI"]
      predActual.dat = as.data.frame(
        list("IDENT_SUBID"=rownames(test), "pred"=pred, "actual"=actual, "proba.pi"=prob1)
      )
      crossValAccuracy = confusionMatrix(pred, actual)$overall["Accuracy"]
      # permutation null: accuracy of the same predictions vs shuffled labels
      nullDistr = as.data.frame(list("nullAcc" = replicate(simplify = TRUE, n = 1000,
        expr = confusionMatrix(pred, sample(actual))$overall["Accuracy"])
      ))
      list(
        "predActual.dat" = predActual.dat,
        "CV_Accuracy" = crossValAccuracy,
        "nullDistrAcc.dat" = nullDistr
      )
    })
    return(crossVal.res)
  }
)

3.1 Compile results:

# Extract the mean and SD of the per-fold test-set accuracies for each model.
# (Replaces the original index loop, which grew a NULL object, hard-coded the
# fold count at 1000, and used 1:length() instead of seq_along().)
accuracies = lapply(permutationResults, function(model_res) {
  acc = vapply(model_res, function(fold) fold$CV_Accuracy, numeric(1))
  data.frame(estimate = mean(acc), SD = sd(acc))
})
names(accuracies) = names(data_list)

# Extract the (sorted) permutation null distributions and average them across
# folds. Access the fold element by name rather than the fragile positional
# index [[3]] used originally.
nullDistr = lapply(permutationResults, function(x) {
  tmp = lapply(x, function(y) {
    sort(y$nullDistrAcc.dat$nullAcc) # sorted null accuracies for one fold
  })
  return(rowMeans(Reduce("cbind", tmp)))
})
names(nullDistr) = names(data_list)

# grand-average null distribution pooled across all models
masterNull = rowMeans(Reduce("cbind", nullDistr))

# Permutation p-value: proportion of null accuracies exceeding the observed
# mean accuracy, with the standard +1 correction so p is never exactly 0.
perm_p = function(null_acc, observed) {
  (1 + sum(null_acc > observed)) / (1 + length(null_acc))
}

# each model's mean CV accuracy vs. its own (sorted, averaged) null
perm.pval = lapply(names(data_list), function(x) {
  perm_p(nullDistr[[x]], mean(accuracies[[x]]$estimate))
})
names(perm.pval) = names(data_list)

# each model's mean CV accuracy vs. the common null pooled over all models
pval = lapply(names(data_list), function(x) {
  perm_p(masterNull, mean(accuracies[[x]]$estimate))
})
names(pval) = names(data_list)

3.2 Results:

# Summary table: OOB accuracy, CV accuracy (mean / SD), null distribution
# (mean / variance), and permutation p-values against each model's own null
# ("p") and against the pooled common null ("common.p").
summary_cols = list(
  "OOB_Acc" = round(OOBs, 4),
  "CV_Acc." = round(sapply(accuracies, function(x) x$estimate), 4),
  "CV_Acc_Var" = round(sapply(accuracies, function(x) x$SD), 4),
  "Null_Acc" = round(sapply(nullDistr, mean), 4),
  "Null_Var" = round(sapply(nullDistr, var), 4),
  "p" = round(unlist(perm.pval), 3),
  "common.p" = round(unlist(pval), 3)
)
as.data.frame(summary_cols, row.names = names(data_list))

4 ROC curves for each:

# Average each participant's predicted P(PI) over every test fold in which
# she appeared.
aggregatePreds = mclapply(permutationResults, function(model_res) {
  all_preds = Reduce("rbind", lapply(model_res, function(fold) fold$predActual.dat))
  all_preds %>%
    group_by(IDENT_SUBID) %>%
    summarize(avgPred = mean(proba.pi))
}, mc.cores = (detectCores() - 1))
# Merge in the group labels, coding PI as 1 and COMP as 0. The join key is
# named explicitly so left_join() neither guesses nor prints a "Joining, by ="
# message.
labels = readRDS("../../../wave1labels.rds")
aggregatePreds = lapply(aggregatePreds, function(x) {
  lbl = select(labels, IDENT_SUBID, GROUP) %>%
    mutate_at("GROUP", ~factor(ifelse(. == "PI", 1, 0)))
  left_join(x, lbl, by = "IDENT_SUBID")
})
names(aggregatePreds) = names(data_list)
# plot: one ROC curve per model, arranged in a 2x4 grid
# NOTE(review): par(mfrow = ...) changes base-graphics state globally for the
# session; consider saving/restoring par() if more base plots follow.
par(mfrow = c(2, 4))
#pdf(height = 4.5)
# Curves use the fold-averaged P(PI) as predictor and the 0/1 group factor as
# response. pROC's roc() presumably auto-selects the comparison direction
# ("auto" default) — verify orientation if AUCs look inverted.
lapply(names(aggregatePreds), function(x) {
    plot(roc(predictor = aggregatePreds[[x]]$avgPred, response = aggregatePreds[[x]]$GROUP),
         xlim=c(1, 0), ylim = c(0, 1), main = x)
})

5 Visualize the distribution of the accuracies

5.1 Compile results into dataframes for plotting

## Add labels and convert from wide to long format to plot the accuracy
## distributions against the common null.
dataNames = names(data_list)
dataType = c(rep(c("connectivity", "dissimilarity"), 2), rep("structural", 4))
# raw per-fold test-set accuracies, extracted by name rather than position
rawAcc = lapply(permutationResults, function(x) {
  sapply(x, function(y) {
    y$CV_Accuracy
  })
})
# One long data frame per model: observed fold accuracies stacked on top of
# the common permuted null. seq_along()/length() replace the hard-coded 1:8
# and 1000 of the original.
perm_plt_data = Reduce("rbind", lapply(seq_along(rawAcc), function(x) {
  n = length(rawAcc[[x]])
  n_null = length(masterNull)
  data.frame("model" = rep(dataNames[x], times = (n + n_null)),
             "dataType" = rep(dataType[x], times = (n + n_null)),
             "Distribution" = c(rep("Test.Set.Repetitions", times = n),
                                rep("Permuted.Null", times = n_null)),
             "Accuracy" = c(rawAcc[[x]], masterNull),
             stringsAsFactors = FALSE)
}))

# separate the functional and structural models for plotting
fMRI_plt_data = filter(perm_plt_data, dataType != "structural")
StrMRI_plt_data = filter(perm_plt_data, dataType == "structural")

5.2 Generate plots

# Observed mean accuracy per functional model (models starting with "F"),
# drawn as a vertical reference line on each facet.
fmri_means = data.frame("model" = dataNames,
                        "avg" = sapply(rawAcc, mean)) %>%
  filter(grepl("^F", model))

fMRI_plt = ggplot(fMRI_plt_data, aes(Accuracy, fill = Distribution)) +
  geom_density(alpha = .3) +
  geom_vline(data = fmri_means, aes(xintercept = avg)) +
  facet_grid(~model) +
  labs(title = "Functional Data Model Accuracies and Permutation Test Results") +
  theme(panel.background = element_rect(fill = "white"),
        plot.title = element_text(hjust = .5))

# Observed mean accuracy per structural model (everything not starting with
# "F"), drawn as a vertical reference line on each facet.
str_means = data.frame("model" = dataNames,
                       "avg" = sapply(rawAcc, mean)) %>%
  filter(!grepl("^F", model))

StrMRI_plt = ggplot(StrMRI_plt_data, aes(Accuracy, fill = Distribution)) +
  geom_density(alpha = .3) +
  geom_vline(data = str_means, aes(xintercept = avg)) +
  facet_grid(~model) +
  labs(title = "Structural Data Model Accuracies and Permutation Test Results") +
  theme(panel.background = element_rect(fill = "white"),
        plot.title = element_text(hjust = .5))

5.3 Plot each against the global null distribution for all models

# render both density plots (explicit print so this works outside knitr too)
print(fMRI_plt)

print(StrMRI_plt)

6 Visualizing Important Variables

6.0.1 Reformat the variable names for interpretability (dropping Harvard–Oxford region numbers)

# read in the Harvard-Oxford region-name key (roiName column is used below)
regionNames = read.csv("../../../documentation/ho_key.csv",
                       stringsAsFactors = FALSE)
# store the OOB AUC for later use in the plot subtitle
# NOTE(review): `rf.list` is not defined anywhere in this file's visible code —
# the models fitted above live in `bestForest.list`. Presumably rf.list is
# loaded elsewhere (or this should read bestForest.list); confirm before
# running.
FC_OOB_AUC = AUC(y_pred = predict(rf.list$FC, type = 'prob')[,"PI"],
                 y_true = ifelse(rf.list$FC$y=="PI",1,0))
# first for the functional connectivity: build an importance data frame and
# flag the top 2.5% of variables by MeanDecreaseAccuracy
fcon.classify.imp = data.frame("var" = rownames(importance(rf.list[["FC"]])),
                               as.data.frame(importance(rf.list[["FC"]]))) %>%
  arrange(desc(MeanDecreaseAccuracy)) %>%
  mutate(bestVar = MeanDecreaseAccuracy >= quantile(MeanDecreaseAccuracy, probs = .975))

# Map one half of a pairwise variable name (the halves are separated by
# ".X.") to its Harvard-Oxford region name. `part` selects half 1 or 2.
# Subcortical rows follow the 48 cortical rows in the key, hence the offset.
# (Replaces two duplicated anonymous functions applied via apply() on a
# data.frame in the original.)
ho_region_name = function(var_name, part, region_key, n_cortical = 48) {
  half = strsplit(var_name, split = "\\.X\\.")[[1]][part]
  roi_num = as.numeric(sub("^.+_", "", half))
  if (grepl("_cortical_", half)) {
    region_key$roiName[roi_num]
  } else {
    region_key$roiName[roi_num + n_cortical]
  }
}

# add region names for both members of each pairwise connection
fcon.classify.imp$region_1 = vapply(as.character(fcon.classify.imp$var),
                                    ho_region_name, character(1),
                                    part = 1, region_key = regionNames,
                                    USE.NAMES = FALSE)
fcon.classify.imp$region_2 = vapply(as.character(fcon.classify.imp$var),
                                    ho_region_name, character(1),
                                    part = 2, region_key = regionNames,
                                    USE.NAMES = FALSE)
# subset by the most important variables (top 2.5%, flagged in bestVar)
best.fcon.imp = fcon.classify.imp %>%
  filter(bestVar) %>%
  # paste0() is vectorized, so the original rowwise()/ungroup() pair is
  # unnecessary and has been dropped
  mutate(connection = paste0(region_1, "_WITH_", region_2)) %>%
  mutate(connection = gsub("[[:punct:]]", "", gsub(" ", "", connection))) %>%
  # ratio importance score: local importance of PI relative to COMP
  mutate(LocalImpRatio = PI / COMP) %>%
  # rescale the importances to [0, 1] for aesthetics (across() replaces the
  # superseded mutate_at())
  mutate(across(c(COMP, PI, LocalImpRatio), ~ (. - min(.)) / (max(.) - min(.)))) %>%
  # rescale MeanDecreaseAccuracy to [0, 100]; the +100 shift cancels
  # algebraically but is kept verbatim so results match the original exactly
  mutate(across(MeanDecreaseAccuracy,
                ~ (((100 + .) - min((100 + .))) / (max((100 + .)) - min((100 + .)))) * 100))

# order connections by importance so the flipped bar chart reads top-down
resorted_names = unlist(arrange(
  best.fcon.imp, desc(MeanDecreaseAccuracy)
  )$connection)
best.fcon.imp = best.fcon.imp %>%
  mutate(connection = factor(connection, levels = rev(resorted_names)))

6.1 Functional Connectivity variable importance plot

# FC variable-importance bar chart. geom_col() is the modern spelling of
# geom_bar(stat = "identity"); the title typo ("classifiying") is fixed.
plt_FC_varimp = ggplot(data = best.fcon.imp) +
  geom_col(
    aes(y = MeanDecreaseAccuracy,
        x = connection,
        # LocalImpRatio was already scaled to [0, 1] above, so this rescale is
        # a no-op kept for symmetry with the other plots
        fill = (LocalImpRatio - min(LocalImpRatio)) / (max(LocalImpRatio) - min(LocalImpRatio))),
    color = "lightgrey") +
  scale_fill_gradient2(low = 'orange',
                       high = 'blue',
                       mid = 'white',
                       midpoint = .5,
                       limit = c(0, 1),
                       name = "Relative (more)\nimportance to\nPI over COMP") +
  labs(y = "Importance to Classification Accuracy (scaled between 0 and 100)", x = NULL,
       title = "Variable Importances classifying group from FC-MRI data",
       subtitle = paste0("OOB-AUC = ", round(FC_OOB_AUC, digits = 4))) +
  coord_flip() +
  theme(plot.title.position = "plot", plot.title = element_text(hjust = .5),
        plot.subtitle = element_text(hjust = .5))
### Show plot
plt_FC_varimp + theme(text = element_text(size = 20))

6.2 Structural data variable importance plot

# OOB AUC for the structural model (used in the plot subtitle)
Str_OOB_AUC = AUC(y_pred = predict(rf.list$Str, type = "prob")[,"PI"],
                  y_true = ifelse(rf.list$Str$y == "PI", 1, 0))

# Importance data frame with the local importance ratio (PI relative to COMP),
# rescaled for plotting. across() replaces the superseded mutate_at().
str.classify.imp = data.frame("var" = rownames(importance(rf.list[["Str"]])),
                              as.data.frame(importance(rf.list[["Str"]]))) %>%
  arrange(desc(MeanDecreaseAccuracy)) %>%
  # compute local importance ratio (relative local importance)
  mutate(LocalImpRatio = PI / COMP) %>%
  # rescale the importances to [0, 1] for aesthetics
  mutate(across(c(COMP, PI, LocalImpRatio), ~ (. - min(.)) / (max(.) - min(.)))) %>%
  # rescale MeanDecreaseAccuracy to [0, 100]; the +100 shift cancels
  # algebraically but is kept verbatim so results match the original exactly
  mutate(across(MeanDecreaseAccuracy,
                ~ (((100 + .) - min((100 + .))) / (max((100 + .)) - min((100 + .)))) * 100))

# reorder the variable factor by importance for plotting purposes
resorted_names = unlist(arrange(
  str.classify.imp, desc(MeanDecreaseAccuracy)
  )$var)
str.classify.imp = str.classify.imp %>%
  mutate(var = factor(var, levels = rev(resorted_names)))

# Structural variable-importance bar chart. geom_col() replaces
# geom_bar(stat = "identity"); the title typo ("classifiying") is fixed.
plt_STR_varimp = ggplot(data = str.classify.imp) +
  geom_col(aes(
    y = MeanDecreaseAccuracy,
    x = var,
    # LocalImpRatio was already scaled to [0, 1] above; no-op kept for symmetry
    fill = (LocalImpRatio - min(LocalImpRatio)) / (max(LocalImpRatio) - min(LocalImpRatio))),
    color = 'lightgrey') +
  scale_fill_gradient2(low = 'orange',
                       high = 'blue',
                       mid = 'white',
                       midpoint = .5,
                       limit = c(0, 1),
                       name = "Relative (more)\nimportance to\nPI over COMP") +
  labs(y = "Importance to Classification Accuracy (scaled between 0 and 100)", x = NULL,
       title = "Variable Importances classifying group from Structural MRI",
       subtitle = paste0("OOB-AUC = ", round(Str_OOB_AUC, digits = 4))) +
  coord_flip() +
  theme(plot.title.position = "plot", plot.title = element_text(hjust = .5),
        plot.subtitle = element_text(hjust = .5))

## Show plot:
plt_STR_varimp

6.3 Structural variable importance plot with no whole brain vol.

# OOB AUC for the structural model without whole-brain volume
Str_noWBV_OOB_AUC = AUC(y_pred = predict(rf.list$Str_noWBV, type = "prob")[,"PI"],
                        y_true = ifelse(rf.list$Str_noWBV$y == "PI", 1, 0))

# Importance data frame with the local importance ratio (PI relative to COMP),
# rescaled for plotting. across() replaces the superseded mutate_at().
Str_noWBV.classify.imp = data.frame("var" = rownames(importance(rf.list[["Str_noWBV"]])),
                                    as.data.frame(importance(rf.list[["Str_noWBV"]]))) %>%
  arrange(desc(MeanDecreaseAccuracy)) %>%
  # compute local importance ratio (relative local importance)
  mutate(LocalImpRatio = PI / COMP) %>%
  # rescale the importances to [0, 1] for aesthetics
  mutate(across(c(COMP, PI, LocalImpRatio), ~ (. - min(.)) / (max(.) - min(.)))) %>%
  # rescale MeanDecreaseAccuracy to [0, 100]; the +100 shift cancels
  # algebraically but is kept verbatim so results match the original exactly
  mutate(across(MeanDecreaseAccuracy,
                ~ (((100 + .) - min((100 + .))) / (max((100 + .)) - min((100 + .)))) * 100))

# reorder the variable factor by importance for plotting purposes
resorted_names = unlist(arrange(
  Str_noWBV.classify.imp, desc(MeanDecreaseAccuracy)
  )$var)
Str_noWBV.classify.imp = Str_noWBV.classify.imp %>%
  mutate(var = factor(var, levels = rev(resorted_names)))

# Structural (no whole-brain volume) variable-importance bar chart. geom_col()
# replaces geom_bar(stat = "identity"); title typo ("classifiying") fixed.
plt_Str_noWBV_varimp = ggplot(data = Str_noWBV.classify.imp) +
  geom_col(aes(
    y = MeanDecreaseAccuracy,
    x = var,
    # LocalImpRatio was already scaled to [0, 1] above; no-op kept for symmetry
    fill = (LocalImpRatio - min(LocalImpRatio)) / (max(LocalImpRatio) - min(LocalImpRatio))),
    color = 'lightgrey') +
  scale_fill_gradient2(low = 'orange',
                       high = 'blue',
                       mid = 'white',
                       midpoint = .5,
                       limit = c(0, 1),
                       name = "Relative (more)\nimportance to\nPI over COMP") +
  labs(y = "Importance to Classification Accuracy (scaled between 0 and 100)", x = NULL,
       title = "Variable Importances classifying group from Structural MRI (without whole-brain volume)",
       subtitle = paste0("OOB-AUC = ", round(Str_noWBV_OOB_AUC, digits = 4))) +
  coord_flip() +
  theme(plot.title = element_text(hjust = .5), plot.title.position = "plot",
        plot.subtitle = element_text(hjust = .5))
## Show plot:
plt_Str_noWBV_varimp

7 with the Functional dissimilarity data

# store the OOB AUC for the functional dissimilarity model for later use
FD_OOB_AUC = AUC(y_true = ifelse(rf.list$FD$y == "PI", 1, 0),
                 y_pred = predict(rf.list$FD, type = "prob")[,"PI"])
# for the functional dissimilarity data: importance data frame and top-2.5%
# flag by MeanDecreaseAccuracy
fdis.classify.imp = data.frame("var" = rownames(importance(rf.list[["FD"]])),
                               as.data.frame(importance(rf.list[["FD"]]))) %>%
  arrange(desc(MeanDecreaseAccuracy)) %>%
  mutate(bestVar = MeanDecreaseAccuracy >= quantile(MeanDecreaseAccuracy, probs = .975))

# Harvard-Oxford region-name lookup (same helper as the FC section; redefined
# identically here so this chunk stands on its own). `part` selects half 1 or
# 2 of a ".X."-separated variable name; subcortical rows follow the 48
# cortical rows in the key.
ho_region_name = function(var_name, part, region_key, n_cortical = 48) {
  half = strsplit(var_name, split = "\\.X\\.")[[1]][part]
  roi_num = as.numeric(sub("^.+_", "", half))
  if (grepl("_cortical_", half)) {
    region_key$roiName[roi_num]
  } else {
    region_key$roiName[roi_num + n_cortical]
  }
}

# add region names for both members of each pairwise connection
fdis.classify.imp$region_1 = vapply(as.character(fdis.classify.imp$var),
                                    ho_region_name, character(1),
                                    part = 1, region_key = regionNames,
                                    USE.NAMES = FALSE)
fdis.classify.imp$region_2 = vapply(as.character(fdis.classify.imp$var),
                                    ho_region_name, character(1),
                                    part = 2, region_key = regionNames,
                                    USE.NAMES = FALSE)
# subset by the most important variables (top 2.5%, flagged in bestVar)
best.fdis.imp = fdis.classify.imp %>%
  filter(bestVar) %>%
  # paste0() is vectorized, so the original rowwise()/ungroup() pair is
  # unnecessary and has been dropped
  mutate(connection = paste0(region_1, "_WITH_", region_2)) %>%
  mutate(connection = gsub("[[:punct:]]", "", gsub(" ", "", connection))) %>%
  # ratio importance score: local importance of PI relative to COMP
  mutate(LocalImpRatio = PI / COMP) %>%
  # rescale the importances to [0, 1] for aesthetics (across() replaces the
  # superseded mutate_at())
  mutate(across(c(COMP, PI, LocalImpRatio), ~ (. - min(.)) / (max(.) - min(.)))) %>%
  # rescale MeanDecreaseAccuracy to [0, 100]; the +100 shift cancels
  # algebraically but is kept verbatim so results match the original exactly
  mutate(across(MeanDecreaseAccuracy,
                ~ (((100 + .) - min((100 + .))) / (max((100 + .)) - min((100 + .)))) * 100))

# order connections by importance so the flipped bar chart reads top-down
resorted_names = unlist(arrange(
  best.fdis.imp, desc(MeanDecreaseAccuracy)
  )$connection)
best.fdis.imp = best.fdis.imp %>%
  mutate(connection = factor(connection, levels = rev(resorted_names)))

# FD variable-importance bar chart. geom_col() replaces
# geom_bar(stat = "identity"); the title typo ("classifiying") is fixed.
plt_FD_varimp = ggplot(data = best.fdis.imp) +
  geom_col(aes(
    y = MeanDecreaseAccuracy,
    x = connection,
    # LocalImpRatio was already scaled to [0, 1] above; no-op kept for symmetry
    fill = (LocalImpRatio - min(LocalImpRatio)) / (max(LocalImpRatio) - min(LocalImpRatio))),
    color = 'lightgrey') +
  scale_fill_gradient2(low = 'orange',
                       high = 'blue',
                       mid = 'white',
                       midpoint = .5,
                       limit = c(0, 1),
                       name = "Relative (more)\nimportance to\nPI over COMP") +
  labs(y = "Importance to Classification Accuracy (scaled between 0 and 100)", x = NULL,
       title = "Variable Importances classifying group from FD-MRI data",
       subtitle = paste0("OOB-AUC = ", round(FD_OOB_AUC, digits = 4))) +
  coord_flip() +
  theme(plot.title.position = "plot", plot.title = element_text(hjust = .5),
        plot.subtitle = element_text(hjust = .5))
## Show plot::
plt_FD_varimp + theme(text = element_text(size = 20))

7.1 network graph based on important functional connectivity variables

library(igraph)
# Edge list from the top (bestVar) connections only: transpose + as.vector
# interleaves region_1/region_2 into igraph's flat from/to format.
best_fc = fcon.classify.imp[fcon.classify.imp$bestVar, ]
e_FC <- as.vector(
  t(as.matrix(
    best_fc[, c("region_1", "region_2")])
  )
)
# edge width = variable importance from the model
# vertex size = hub score from the graph
g_FC <- igraph::graph(edges = e_FC, directed = FALSE)
# BUG FIX: edge widths must come from the bestVar rows only. The original
# passed the full importance table's MeanDecreaseAccuracy, whose length does
# not match the edge count, so widths were silently recycled/mismatched.
plot(g_FC, vertex.label.cex = 2, cex.main = 3,
     vertex.size = round(hub.score(g_FC)$vector * 10),
     edge.width = round(best_fc$MeanDecreaseAccuracy) + 1)
title(main = "Network of Most Important Pairwise Edges (functional connections)",
      cex.main = 3)